{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "e2c01a17",
   "metadata": {},
   "source": [
    "### Search your Google Drive knowledge base with fully local processing."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "df7609cf",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m24.2\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m25.3\u001b[0m\n",
      "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n"
     ]
    }
   ],
   "source": [
    "%pip install -qU \"langchain==0.3.27\" \"langchain-core<1.0.0,>=0.3.78\" \"langchain-text-splitters<1.0.0,>=0.3.9\" langchain_ollama langchain_chroma langchain_community google-auth google-auth-oauthlib google-auth-httplib2 google-api-python-client python-docx"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "144bdf7c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import logging\n",
    "import os\n",
    "import re\n",
    "import sys\n",
    "import hashlib\n",
    "from pathlib import Path\n",
    "from enum import StrEnum\n",
    "from typing import Iterable, Optional\n",
    "\n",
    "import gradio as gr\n",
    "from langchain_core.documents import Document\n",
    "from langchain_text_splitters import RecursiveCharacterTextSplitter, MarkdownHeaderTextSplitter\n",
    "from langchain_ollama import OllamaEmbeddings, ChatOllama\n",
    "from langchain.storage import InMemoryStore\n",
    "from langchain_chroma import Chroma\n",
    "from langchain_community.document_loaders import TextLoader\n",
    "from google.oauth2.credentials import Credentials\n",
    "from google_auth_oauthlib.flow import InstalledAppFlow\n",
    "from google.auth.transport.requests import Request\n",
    "from googleapiclient.discovery import build\n",
    "from googleapiclient.http import MediaIoBaseDownload\n",
    "from googleapiclient.errors import HttpError\n",
    "from docx import Document as DocxDocument"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dfdb143d",
   "metadata": {},
   "outputs": [],
   "source": [
    "logger = logging.getLogger('drive_sage')\n",
    "logger.setLevel(logging.DEBUG)\n",
    "\n",
    "if not logger.handlers:\n",
    "    handler = logging.StreamHandler(sys.stdout)\n",
    "    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')\n",
    "    handler.setFormatter(formatter)\n",
    "    logger.addHandler(handler)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "41df43aa",
   "metadata": {},
   "outputs": [],
   "source": [
    "SCOPES = ['https://www.googleapis.com/auth/drive.readonly']\n",
    "APP_ROOT = Path.cwd()\n",
    "DATA_DIR = APP_ROOT / '.drive_sage'\n",
    "DOWNLOAD_DIR = DATA_DIR / 'downloads'\n",
    "VECTORSTORE_DIR = DATA_DIR / 'chroma'\n",
    "TOKEN_PATH = DATA_DIR / 'token.json'\n",
    "MANIFEST_PATH = DATA_DIR / 'manifest.json'\n",
    "CLIENT_SECRET_FILE = APP_ROOT / 'client_secret_202216035337-4qson0c08g71u8uuihv6v46arv64nhvg.apps.googleusercontent.com.json'\n",
    "\n",
    "for path in (DATA_DIR, DOWNLOAD_DIR, VECTORSTORE_DIR):\n",
    "    path.mkdir(parents=True, exist_ok=True)\n",
    "\n",
    "FILE_TYPE_OPTIONS = {\n",
    "    'txt': {\n",
    "        'label': '.txt - Plain text',\n",
    "        'extensions': ['.txt'],\n",
    "        'mime_types': ['text/plain'],\n",
    "    },\n",
    "    'md': {\n",
    "        'label': '.md - Markdown',\n",
    "        'extensions': ['.md'],\n",
    "        'mime_types': ['text/markdown', 'text/plain'],\n",
    "    },\n",
    "    'docx': {\n",
    "        'label': '.docx - Word (OpenXML)',\n",
    "        'extensions': ['.docx'],\n",
    "        'mime_types': ['application/vnd.openxmlformats-officedocument.wordprocessingml.document'],\n",
    "    },\n",
    "    'doc': {\n",
    "        'label': '.doc - Word 97-2003',\n",
    "        'extensions': ['.doc'],\n",
    "        'mime_types': ['application/msword', 'application/vnd.ms-word.document.macroenabled.12'],\n",
    "    },\n",
    "    'gdoc': {\n",
    "        'label': 'Google Docs (exported)',\n",
    "        'extensions': ['.docx'],\n",
    "        'mime_types': ['application/vnd.google-apps.document'],\n",
    "    },\n",
    "}\n",
    "\n",
    "FILE_TYPE_LABEL_TO_KEY = {config['label']: key for key, config in FILE_TYPE_OPTIONS.items()}\n",
    "DEFAULT_FILE_TYPE_KEYS = ['txt', 'md', 'docx', 'doc', 'gdoc']\n",
    "DEFAULT_FILE_TYPE_LABELS = [FILE_TYPE_OPTIONS[key]['label'] for key in DEFAULT_FILE_TYPE_KEYS]\n",
    "\n",
    "MIME_TYPE_TO_EXTENSION = {}\n",
    "for key, config in FILE_TYPE_OPTIONS.items():\n",
    "    extension = config['extensions'][0]\n",
    "    for mime in config['mime_types']:\n",
    "        MIME_TYPE_TO_EXTENSION[mime] = extension\n",
    "\n",
    "GOOGLE_EXPORT_FORMATS = {\n",
    "    'application/vnd.google-apps.document': (\n",
    "        'application/vnd.openxmlformats-officedocument.wordprocessingml.document',\n",
    "        '.docx'\n",
    "    ),\n",
    "}\n",
    "\n",
    "SIMILARITY_DISTANCE_MAX = float(os.getenv('DRIVE_SAGE_DISTANCE_MAX', '1.2'))\n",
    "MAX_CONTEXT_SNIPPET_CHARS = 1200\n",
    "HASH_BLOCK_SIZE = 65536\n",
    "EMBED_MODEL = os.getenv('DRIVE_SAGE_EMBED_MODEL', 'nomic-embed-text')\n",
    "CHAT_MODEL = os.getenv('DRIVE_SAGE_CHAT_MODEL', 'llama3.1:latest')\n",
    "\n",
    "CUSTOM_CSS = \"\"\"\n",
    "#chat-column {\n",
    "    height: 80vh;\n",
    "}\n",
    "#chat-column > div {\n",
    "    height: 100%;\n",
    "}\n",
    "#chat-column .gradio-chatbot,\n",
    "#chat-column .gradio-chat-interface,\n",
    "#chat-column .gradio-chatinterface {\n",
    "    height: 100%;\n",
    "}\n",
    "#chat-output {\n",
    "    height: 100%;\n",
    "}\n",
    "#chat-output .overflow-y-auto {\n",
    "    max-height: 100% !important;\n",
    "}\n",
    "#chat-output .h-full {\n",
    "    height: 100% !important;\n",
    "}\n",
    "\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "225a921a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def build_drive_service():\n",
    "    creds = None\n",
    "    if TOKEN_PATH.exists():\n",
    "        try:\n",
    "            creds = Credentials.from_authorized_user_file(str(TOKEN_PATH), SCOPES)\n",
    "        except Exception as exc:\n",
    "            logger.warning('Failed to load cached credentials: %s', exc)\n",
    "            TOKEN_PATH.unlink(missing_ok=True)\n",
    "            creds = None\n",
    "\n",
    "    if not creds or not creds.valid:\n",
    "        if creds and creds.expired and creds.refresh_token:\n",
    "            try:\n",
    "                creds.refresh(Request())\n",
    "            except Exception as exc:\n",
    "                logger.warning('Refreshing credentials failed: %s', exc)\n",
    "                creds = None\n",
    "\n",
    "        if not creds or not creds.valid:\n",
    "            if not CLIENT_SECRET_FILE.exists():\n",
    "                raise FileNotFoundError(\n",
    "                    'client_secret.json not found. Download it from Google Cloud Console and place it next to this notebook.'\n",
    "                )\n",
    "            flow = InstalledAppFlow.from_client_secrets_file(str(CLIENT_SECRET_FILE), SCOPES)\n",
    "            creds = flow.run_local_server(port=0)\n",
    "\n",
    "        with TOKEN_PATH.open('w', encoding='utf-8') as token_file:\n",
    "            token_file.write(creds.to_json())\n",
    "            \n",
    "    return build('drive', 'v3', credentials=creds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e0acb8ec",
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_manifest() -> dict:\n",
    "    if MANIFEST_PATH.exists():\n",
    "        try:\n",
    "            with MANIFEST_PATH.open('r', encoding='utf-8') as fp:\n",
    "                raw = json.load(fp)\n",
    "            if isinstance(raw, dict):\n",
    "                normalized: dict[str, dict] = {}\n",
    "                for file_id, entry in raw.items():\n",
    "                    if isinstance(entry, dict):\n",
    "                        normalized[file_id] = entry\n",
    "                    else:\n",
    "                        normalized[file_id] = {'modified': str(entry)}\n",
    "                return normalized\n",
    "        except json.JSONDecodeError:\n",
    "            logger.warning('Manifest file is corrupted; resetting cache.')\n",
    "    return {}\n",
    "\n",
    "def save_manifest(manifest: dict) -> None:\n",
    "    with MANIFEST_PATH.open('w', encoding='utf-8') as fp:\n",
    "        json.dump(manifest, fp, indent=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "43098d19",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Metadata(StrEnum):\n",
    "    ID = 'id'\n",
    "    SOURCE = 'source'\n",
    "    PARENT_ID = 'parent_id'\n",
    "    FILE_TYPE = 'file_type'\n",
    "    TITLE = 'title'\n",
    "    MODIFIED = 'modified'\n",
    "\n",
    "def metadata_key(key: Metadata) -> str:\n",
    "    return key.value\n",
    "\n",
    "embeddings = OllamaEmbeddings(model=EMBED_MODEL)\n",
    "\n",
    "try:\n",
    "    vectorstore = Chroma(\n",
    "        collection_name='drive_sage',\n",
    "        embedding_function=embeddings,\n",
    "    )\n",
    "except Exception as exc:\n",
    "    logger.exception('Failed to initialize in-memory Chroma vector store')\n",
    "    raise RuntimeError('Unable to initialize Chroma vector store without persistence.') from exc\n",
    "\n",
    "docstore = InMemoryStore()\n",
    "model = ChatOllama(model=CHAT_MODEL)\n",
    "\n",
    "DEFAULT_TEXT_SPLITTER = RecursiveCharacterTextSplitter(\n",
    "    chunk_size=1000,\n",
    "    chunk_overlap=150,\n",
    "    separators=['\\n\\n', '\\n', ' ', '']\n",
    ")\n",
    "MARKDOWN_HEADERS = [('#', 'Header 1'), ('##', 'Header 2'), ('###', 'Header 3')]\n",
    "MARKDOWN_SPLITTER = MarkdownHeaderTextSplitter(headers_to_split_on=MARKDOWN_HEADERS, strip_headers=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "116be5f6",
   "metadata": {},
   "outputs": [],
   "source": [
    "def safe_filename(name: str, max_length: int = 120) -> str:\n",
    "    sanitized = re.sub(r'[^A-Za-z0-9._-]', '_', name)\n",
    "    sanitized = sanitized.strip('._') or 'untitled'\n",
    "    return sanitized[:max_length]\n",
    "\n",
    "def determine_extension(metadata: dict) -> str:\n",
    "    mime_type = metadata.get('mimeType', '')\n",
    "    name = metadata.get('name')\n",
    "    if name and Path(name).suffix:\n",
    "        return Path(name).suffix.lower()\n",
    "    if mime_type in GOOGLE_EXPORT_FORMATS:\n",
    "        return GOOGLE_EXPORT_FORMATS[mime_type][1]\n",
    "    return MIME_TYPE_TO_EXTENSION.get(mime_type, '.txt')\n",
    "\n",
    "def cached_file_path(metadata: dict) -> Path:\n",
    "    file_id = metadata.get('id', 'unknown')\n",
    "    extension = determine_extension(metadata)\n",
    "    safe_name = safe_filename(Path(metadata.get('name', file_id)).stem)\n",
    "    return DOWNLOAD_DIR / f'{safe_name}_{file_id}{extension}'\n",
    "\n",
    "def hash_file(path: Path) -> str:\n",
    "    digest = hashlib.sha1()\n",
    "    with path.open('rb') as fh:\n",
    "        while True:\n",
    "            block = fh.read(HASH_BLOCK_SIZE)\n",
    "            if not block:\n",
    "                break\n",
    "            digest.update(block)\n",
    "    return digest.hexdigest()\n",
    "\n",
    "def manifest_version(entry: dict | str | None) -> Optional[str]:\n",
    "    if entry is None:\n",
    "        return None\n",
    "    if isinstance(entry, str):\n",
    "        return entry\n",
    "    if isinstance(entry, dict):\n",
    "        return entry.get('modified')\n",
    "    return None\n",
    "\n",
    "def update_manifest_entry(manifest: dict, *, file_id: str, modified: str, path: Path, mime_type: str, name: str) -> None:\n",
    "    manifest[file_id] = {\n",
    "        'modified': modified,\n",
    "        'path': str(path),\n",
    "        'mimeType': mime_type,\n",
    "        'name': name,\n",
    "        'file_type': Path(path).suffix.lower(),\n",
    "    }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d5fe85b9",
   "metadata": {},
   "outputs": [],
   "source": [
    "def list_drive_text_files(service, folder_id: Optional[str], allowed_mime_types: list[str], limit: Optional[int]) -> list[dict]:\n",
    "    query_parts = [\"trashed = false\"]\n",
    "    mime_types = allowed_mime_types or list(MIME_TYPE_TO_EXTENSION.keys())\n",
    "    mime_clause = ' or '.join([f\"mimeType = '{mime}'\" for mime in mime_types])\n",
    "    query_parts.append(f'({mime_clause})')\n",
    "    if folder_id:\n",
    "        query_parts.append(f\"'{folder_id}' in parents\")\n",
    "    query = ' and '.join(query_parts)\n",
    "\n",
    "    files: list[dict] = []\n",
    "    page_token: Optional[str] = None\n",
    "\n",
    "    while True:\n",
    "        page_size = min(100, limit - len(files)) if limit else 100\n",
    "        if page_size <= 0:\n",
    "            break\n",
    "        try:\n",
    "            response = service.files().list(\n",
    "                q=query,\n",
    "                spaces='drive',\n",
    "                fields='nextPageToken, files(id, name, mimeType, modifiedTime)',\n",
    "                orderBy='modifiedTime desc',\n",
    "                pageToken=page_token,\n",
    "                pageSize=page_size,\n",
    "            ).execute()\n",
    "        except HttpError as exc:\n",
    "            raise RuntimeError(f'Google Drive API error: {exc}') from exc\n",
    "\n",
    "        batch = response.get('files', [])\n",
    "        files.extend(batch)\n",
    "        if limit and len(files) >= limit:\n",
    "            return files[:limit]\n",
    "        page_token = response.get('nextPageToken')\n",
    "        if not page_token:\n",
    "            break\n",
    "    return files\n",
    "\n",
    "def download_drive_file(service, metadata: dict, manifest: dict) -> Path:\n",
    "    file_id = metadata['id']\n",
    "    mime_type = metadata.get('mimeType', '')\n",
    "    cache_path = cached_file_path(metadata)\n",
    "    export_mime = None\n",
    "    if mime_type in GOOGLE_EXPORT_FORMATS:\n",
    "        export_mime, extension = GOOGLE_EXPORT_FORMATS[mime_type]\n",
    "        if cache_path.suffix.lower() != extension:\n",
    "            cache_path = cache_path.with_suffix(extension)\n",
    "\n",
    "\n",
    "    request = (\n",
    "        service.files().export_media(fileId=file_id, mimeType=export_mime)\n",
    "        if export_mime\n",
    "        else service.files().get_media(fileId=file_id)\n",
    "    )\n",
    "\n",
    "    logger.debug('Downloading %s (%s) -> %s', metadata.get('name', file_id), file_id, cache_path)\n",
    "    with cache_path.open('wb') as fh:\n",
    "        downloader = MediaIoBaseDownload(fh, request)\n",
    "        done = False\n",
    "        while not done:\n",
    "            status, done = downloader.next_chunk()\n",
    "            if status:\n",
    "                logger.debug('Download progress %.0f%%', status.progress() * 100)\n",
    "\n",
    "    update_manifest_entry(\n",
    "        manifest,\n",
    "        file_id=file_id,\n",
    "        modified=metadata.get('modifiedTime', ''),\n",
    "        path=cache_path,\n",
    "        mime_type=mime_type,\n",
    "        name=metadata.get('name', cache_path.name),\n",
    "    )\n",
    "    return cache_path"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "27f50b9b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def extract_docx_text(path: Path) -> str:\n",
    "    doc = DocxDocument(str(path))\n",
    "    lines = [paragraph.text.strip() for paragraph in doc.paragraphs if paragraph.text.strip()]\n",
    "    return '\\n'.join(lines)\n",
    "\n",
    "def load_documents(\n",
    "    path: Path,\n",
    "    *,\n",
    "    source_id: Optional[str] = None,\n",
    "    file_type: Optional[str] = None,\n",
    "    modified: Optional[str] = None,\n",
    "    display_name: Optional[str] = None,\n",
    " ) -> list[Document]:\n",
    "    suffix = (file_type or path.suffix or '.txt').lower()\n",
    "    try:\n",
    "        if suffix in {'.txt', '.md'}:\n",
    "            loader = TextLoader(str(path), encoding='utf-8')\n",
    "            documents = loader.load()\n",
    "        elif suffix == '.docx':\n",
    "            documents = [Document(page_content=extract_docx_text(path), metadata={'source': str(path)})]\n",
    "        else:\n",
    "            raise ValueError(f'Unsupported file type: {suffix}')\n",
    "    except UnicodeDecodeError as exc:\n",
    "        raise ValueError(f'Failed to read {path}: {exc}') from exc\n",
    "\n",
    "    base_metadata = {\n",
    "        metadata_key(Metadata.SOURCE): str(path),\n",
    "        metadata_key(Metadata.FILE_TYPE): suffix,\n",
    "        metadata_key(Metadata.TITLE): display_name or path.name,\n",
    "    }\n",
    "    if source_id:\n",
    "        base_metadata[metadata_key(Metadata.ID)] = source_id\n",
    "    if modified:\n",
    "        base_metadata[metadata_key(Metadata.MODIFIED)] = modified\n",
    "\n",
    "    cleaned: list[Document] = []\n",
    "    for doc in documents:\n",
    "        content = doc.page_content.strip()\n",
    "        if not content:\n",
    "            continue\n",
    "        merged_metadata = {**doc.metadata, **base_metadata}\n",
    "        doc.page_content = content\n",
    "        doc.metadata = merged_metadata\n",
    "        cleaned.append(doc)\n",
    "    return cleaned\n",
    "\n",
    "def preprocess(documents: Iterable[Document]) -> list[Document]:\n",
    "    return [doc for doc in documents if doc.page_content]\n",
    "\n",
    "def chunk_documents(doc: Document) -> list[Document]:\n",
    "    parent_id = doc.metadata.get(metadata_key(Metadata.ID))\n",
    "    if not parent_id:\n",
    "        raise ValueError('Document is missing a stable identifier for chunking.')\n",
    "\n",
    "    if doc.metadata.get(metadata_key(Metadata.FILE_TYPE)) == '.md':\n",
    "        markdown_docs = MARKDOWN_SPLITTER.split_text(doc.page_content)\n",
    "        seed_docs = [\n",
    "            Document(page_content=section.page_content, metadata={**doc.metadata, **section.metadata})\n",
    "            for section in markdown_docs\n",
    "        ]\n",
    "    else:\n",
    "        seed_docs = [doc]\n",
    "\n",
    "    chunks = DEFAULT_TEXT_SPLITTER.split_documents(seed_docs)\n",
    "    for idx, chunk in enumerate(chunks):\n",
    "        chunk.metadata[metadata_key(Metadata.PARENT_ID)] = parent_id\n",
    "        chunk.metadata[metadata_key(Metadata.ID)] = f'{parent_id}::chunk-{idx:04d}'\n",
    "        chunk.metadata.setdefault(metadata_key(Metadata.SOURCE), doc.metadata.get(metadata_key(Metadata.SOURCE)))\n",
    "        chunk.metadata.setdefault(metadata_key(Metadata.TITLE), doc.metadata.get(metadata_key(Metadata.TITLE)))\n",
    "    return chunks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0f135f35",
   "metadata": {},
   "outputs": [],
   "source": [
    "def sync_drive_and_index(folder_id=None, selected_types=None, file_limit=None, _state: bool = False, progress=gr.Progress(track_tqdm=False)):\n",
    "    folder = (folder_id or '').strip() or None\n",
    "\n",
    "    selections = selected_types if selected_types is not None else DEFAULT_FILE_TYPE_LABELS\n",
    "    if not isinstance(selections, (list, tuple)):\n",
    "        selections = [selections]\n",
    "    selections = list(selections)\n",
    "\n",
    "    if len(selections) == 0:\n",
    "        yield 'Select at least one file type before syncing.', False\n",
    "        return\n",
    "\n",
    "    chosen_keys: list[str] = []\n",
    "    for item in selections:\n",
    "        key = FILE_TYPE_LABEL_TO_KEY.get(item, item)\n",
    "        if key in FILE_TYPE_OPTIONS:\n",
    "            chosen_keys.append(key)\n",
    "\n",
    "    if not chosen_keys:\n",
    "        yield 'Select at least one file type before syncing.', False\n",
    "        return\n",
    "\n",
    "    allowed_mime_types = sorted({mime for key in chosen_keys for mime in FILE_TYPE_OPTIONS[key]['mime_types']})\n",
    "\n",
    "    limit: Optional[int] = None\n",
    "    limit_warning: Optional[str] = None\n",
    "    if file_limit not in (None, '', 0):\n",
    "        try:\n",
    "            parsed_limit = int(file_limit)\n",
    "            if parsed_limit > 0:\n",
    "                limit = parsed_limit\n",
    "            else:\n",
    "                raise ValueError\n",
    "        except (TypeError, ValueError):\n",
    "            limit_warning = 'File limit must be a positive integer. Syncing all matching files instead.'\n",
    "\n",
    "    log_lines: list[str] = []\n",
    "\n",
    "    def push(message: str) -> str:\n",
    "        log_lines.append(message)\n",
    "        return '\\n'.join(log_lines)\n",
    "\n",
    "    if limit_warning:\n",
    "        logger.warning(limit_warning)\n",
    "        yield push(limit_warning), False\n",
    "\n",
    "    progress(0, 'Authorizing Google Drive access...')\n",
    "    yield push('Authorizing Google Drive access...'), False\n",
    "\n",
    "    try:\n",
    "        service = build_drive_service()\n",
    "    except FileNotFoundError as exc:\n",
    "        error_msg = f'Error: {exc}'\n",
    "        logger.error(error_msg)\n",
    "        yield push(error_msg), False\n",
    "        return\n",
    "    except Exception as exc:\n",
    "        logger.exception('Drive authorization failed')\n",
    "        error_msg = f'Error authenticating with Google Drive: {exc}'\n",
    "        yield push(error_msg), False\n",
    "        return\n",
    "\n",
    "    list_message = 'Listing documents' + (f' (limit {limit})' if limit else '') + '...'\n",
    "    progress(0, list_message)\n",
    "    yield push(list_message), False\n",
    "\n",
    "    try:\n",
    "        files = list_drive_text_files(service, folder, allowed_mime_types, limit)\n",
    "    except Exception as exc:\n",
    "        logger.exception('Listing Drive files failed')\n",
    "        error_msg = f'Error listing Google Drive files: {exc}'\n",
    "        yield push(error_msg), False\n",
    "        return\n",
    "\n",
    "    total = len(files)\n",
    "    if total == 0:\n",
    "        info = 'No documents matching the selected types were found in Google Drive.'\n",
    "        yield push(info), True\n",
    "        return\n",
    "\n",
    "    manifest = load_manifest()\n",
    "    downloaded_count = 0\n",
    "\n",
    "    for index, metadata in enumerate(files, start=1):\n",
    "        file_id = metadata['id']\n",
    "        name = metadata.get('name', file_id)\n",
    "        remote_version = metadata.get('modifiedTime', '')\n",
    "        manifest_entry = manifest.get(file_id)\n",
    "        cache_path = cached_file_path(metadata)\n",
    "        if isinstance(manifest_entry, dict) and manifest_entry.get('path'):\n",
    "            cache_path = Path(manifest_entry['path'])\n",
    "        cached_version = manifest_version(manifest_entry)\n",
    "\n",
    "        if cached_version == remote_version and cache_path.exists():\n",
    "            message = f\"{index}/{total} Skipping cached file: {name} -> {cache_path}\"\n",
    "            progress(index / total, message)\n",
    "            yield push(message), False\n",
    "            continue\n",
    "\n",
    "        download_message = f\"{index}/{total} Downloading {name} -> {cache_path}\"\n",
    "        progress(max((index - 0.5) / total, 0), download_message)\n",
    "        yield push(download_message), False\n",
    "\n",
    "        try:\n",
    "            downloaded_path = download_drive_file(service, metadata, manifest)\n",
    "            index_message = f\"{index}/{total} Indexing {downloaded_path.name}\"\n",
    "            progress(index / total, index_message)\n",
    "            yield push(index_message), False\n",
    "            index_document(\n",
    "                downloaded_path,\n",
    "                source_id=file_id,\n",
    "                file_type=downloaded_path.suffix,\n",
    "                modified=remote_version,\n",
    "                display_name=name,\n",
    "                manifest=manifest,\n",
    "            )\n",
    "            downloaded_count += 1\n",
    "        except Exception as exc:\n",
    "            error_message = f\"{index}/{total} Failed to sync {name}: {exc}\"\n",
    "            logger.exception(error_message)\n",
    "            progress(index / total, error_message)\n",
    "            yield push(error_message), False\n",
    "\n",
    "    if downloaded_count > 0:\n",
    "        save_manifest(manifest)\n",
    "        summary = f'Indexed {downloaded_count} new document(s) from Google Drive.'\n",
    "    else:\n",
    "        summary = 'Google Drive is already in sync.'\n",
    "\n",
    "    progress(1, summary)\n",
    "    yield push(summary), True"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0e2f176b",
   "metadata": {},
   "source": [
    "## RAG Pipeline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "20ad0e80",
   "metadata": {},
   "outputs": [],
   "source": [
    "def persist_vectorstore(_store) -> None:\n",
    "    \"\"\"In-memory mode: Chroma client does not persist between sessions.\"\"\"\n",
    "    return\n",
    "\n",
    "\n",
    "def index_document(\n",
    "    file_path: Path | str,\n",
    "    *,\n",
    "    source_id: Optional[str] = None,\n",
    "    file_type: Optional[str] = None,\n",
    "    modified: Optional[str] = None,\n",
    "    display_name: Optional[str] = None,\n",
    "    manifest: Optional[dict] = None,\n",
    " ) -> tuple[str, int]:\n",
    "    path = Path(file_path)\n",
    "    path = path.expanduser().resolve()\n",
    "    resolved_id = source_id or f'local::{hash_file(path)}'\n",
    "    documents = load_documents(\n",
    "        path,\n",
    "        source_id=resolved_id,\n",
    "        file_type=file_type,\n",
    "        modified=modified,\n",
    "        display_name=display_name,\n",
    "    )\n",
    "    documents = preprocess(documents)\n",
    "    if not documents:\n",
    "        logger.warning('No readable content found in %s; skipping.', path)\n",
    "        return resolved_id, 0\n",
    "\n",
    "    total_chunks = 0\n",
    "    for doc in documents:\n",
    "        doc_id = doc.metadata.get(metadata_key(Metadata.ID), resolved_id)\n",
    "        doc.metadata[metadata_key(Metadata.ID)] = doc_id\n",
    "        vectorstore.delete(where={metadata_key(Metadata.PARENT_ID): doc_id})\n",
    "        chunks = chunk_documents(doc)\n",
    "        if not chunks:\n",
    "            continue\n",
    "        vectorstore.add_documents(chunks)\n",
    "        docstore.mset([(doc_id, doc)])\n",
    "        total_chunks += len(chunks)\n",
    "\n",
    "    persist_vectorstore(vectorstore)\n",
    "    if manifest is not None and not source_id:\n",
    "        update_manifest_entry(\n",
    "            manifest,\n",
    "            file_id=resolved_id,\n",
    "            modified=hash_file(path),\n",
    "            path=path,\n",
    "            mime_type=file_type or Path(path).suffix or '.txt',\n",
    "            name=display_name or path.name,\n",
    "        )\n",
    "    return resolved_id, total_chunks"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a90db6ee",
   "metadata": {},
   "source": [
    "### LLM Interaction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e2e15e99",
   "metadata": {},
   "outputs": [],
   "source": [
    "def retrieve_context(query: str, *, top_k: int = 8, distance_threshold: Optional[float] = SIMILARITY_DISTANCE_MAX):\n",
    "    results_with_scores = vectorstore.similarity_search_with_score(query, k=top_k)\n",
    "    logger.info(f'Matching records: {len(results_with_scores)}')\n",
    "\n",
    "    filtered: list[tuple[Document, float]] = []\n",
    "    for doc, score in results_with_scores:\n",
    "        if score is None:\n",
    "            continue\n",
    "        score_value = float(score)\n",
    "        print(f'DEBUG: Retrieved doc source={doc.metadata.get(metadata_key(Metadata.SOURCE))} distance={score_value}')\n",
    "        if distance_threshold is not None and score_value > distance_threshold:\n",
    "            logger.debug(\n",
    "                'Skipping %s with distance %.4f (above threshold %.4f)',\n",
    "                doc.metadata.get(metadata_key(Metadata.SOURCE)),\n",
    "                score_value,\n",
    "                distance_threshold,\n",
    "            )\n",
    "            continue\n",
    "        filtered.append((doc, score_value))\n",
    "\n",
    "    if not filtered:\n",
    "        return []\n",
    "\n",
    "    for doc, score_value in filtered:\n",
    "        parent_id = doc.metadata.get(metadata_key(Metadata.PARENT_ID))\n",
    "        if parent_id:\n",
    "            parent_doc = docstore.mget([parent_id])[0]\n",
    "            if parent_doc and parent_doc.page_content:\n",
    "                logger.debug(\n",
    "                    'Parent preview (%s | %.3f): %s',\n",
    "                    doc.metadata.get(metadata_key(Metadata.SOURCE), 'unknown'),\n",
    "                    score_value,\n",
    "                    parent_doc.page_content[:400].replace('\\n', ' '),\n",
    "                )\n",
    "\n",
    "    return filtered\n",
    "\n",
    "\n",
    "def build_prompt_sections(relevant_docs: list[tuple[Document, float]]) -> str:\n",
    "    sections: list[str] = []\n",
    "    for idx, (doc, score) in enumerate(relevant_docs, start=1):\n",
    "        source = doc.metadata.get(metadata_key(Metadata.SOURCE), 'unknown')\n",
    "        snippet = doc.page_content.strip()[:MAX_CONTEXT_SNIPPET_CHARS]\n",
    "        section = (\n",
    "            f'[{idx}] Source: {source}\\n'\n",
    "            f'Distance: {score:.3f}\\n'\n",
    "            f'Content:\\n{snippet}'\n",
    "        )\n",
    "        sections.append(section)\n",
    "    return '\\n\\n'.join(sections)\n",
    "\n",
    "\n",
    "def ask(message, history):\n",
    "    relevant_docs = retrieve_context(message)\n",
    "    if not relevant_docs:\n",
    "        yield \"I don't have enough information in the synced documents to answer that yet. Please sync additional files or adjust the filters.\"\n",
    "        return\n",
    "\n",
    "    context = build_prompt_sections(relevant_docs)\n",
    "    prompt = f'''\n",
    "    You are a retrieval-augmented assistant. Use ONLY the facts provided in the context to answer the user.\n",
    "    If the context does not contain the answer, reply exactly: \"I don't have enough information in the synced documents to answer that yet. Please sync additional files.\"\n",
    "    \n",
    "    Context:\\n{context}\n",
    "    '''\n",
    "\n",
    "    messages = [\n",
    "        ('system', prompt),\n",
    "        ('user', message)\n",
    "    ]\n",
    "\n",
    "    stream = model.stream(messages)\n",
    "    response_text = ''\n",
    "\n",
    "    for chunk in stream:\n",
    "        response_text += chunk.content or ''\n",
    "        if not response_text:\n",
    "            continue\n",
    "\n",
    "        yield response_text"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c3e632dc-9e87-4510-9fcd-aa699c27e82b",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "## Gradio UI"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a3d68a74",
   "metadata": {},
   "outputs": [],
   "source": [
    "def chat(message, history, sync_ready):\n",
    "    if message is None:\n",
    "        return ''\n",
    "\n",
    "    text_input = message.get('text', '')\n",
    "    files_uploaded = message.get('files', [])\n",
    "    latest_file_path = Path(files_uploaded[-1]) if files_uploaded else None\n",
    "    if latest_file_path:\n",
    "        manifest = load_manifest()\n",
    "        doc_id, chunk_count = index_document(\n",
    "            latest_file_path,\n",
    "            file_type=latest_file_path.suffix,\n",
    "            display_name=latest_file_path.name,\n",
    "            manifest=manifest,\n",
    "        )\n",
    "        save_manifest(manifest)\n",
    "        logger.info('Indexed upload %s as %s with %s chunk(s)', latest_file_path, doc_id, chunk_count)\n",
    "        if not text_input:\n",
    "            yield f'Indexed document from upload ({chunk_count} chunk(s)).'\n",
    "            return\n",
    "\n",
    "    if not text_input:\n",
    "        return ''\n",
    "\n",
    "    if not sync_ready and not files_uploaded:\n",
    "        yield 'Sync Google Drive before chatting or upload a document first.'\n",
    "        return\n",
    "\n",
    "    for chunk in ask(text_input, history):\n",
    "        yield chunk\n",
    "\n",
    "title = \"Drive Sage\"\n",
    "with gr.Blocks(title=title, fill_height=True, css=CUSTOM_CSS) as ui:\n",
    "    gr.Markdown(f'# {title}')\n",
    "    gr.Markdown('## Search your Google Drive knowledge base with fully local processing.')\n",
    "    sync_state = gr.State(False)\n",
    "\n",
    "    with gr.Row():\n",
    "        with gr.Column(scale=3, elem_id='chat-column'):\n",
    "            gr.ChatInterface(\n",
    "                fn=chat,\n",
    "                chatbot=gr.Chatbot(height='80vh', elem_id='chat-output'),\n",
    "                type='messages',\n",
    "                textbox=gr.MultimodalTextbox(\n",
    "                    file_types=['text', '.txt', '.md'],\n",
    "                    autofocus=True,\n",
    "                    elem_id='chat-input',\n",
    "                ),\n",
    "                additional_inputs=[sync_state],\n",
    "            )\n",
    "        with gr.Column(scale=2, min_width=320):\n",
    "            gr.Markdown('### Google Drive Sync')\n",
    "            drive_folder = gr.Textbox(\n",
    "                label='Folder ID (optional)',\n",
    "                placeholder='Leave blank to scan My Drive root',\n",
    "            )\n",
    "            file_types = gr.CheckboxGroup(\n",
    "                label='File types to sync',\n",
    "                choices=[config['label'] for config in FILE_TYPE_OPTIONS.values()],\n",
    "                value=DEFAULT_FILE_TYPE_LABELS,\n",
    "            )\n",
    "            file_limit = gr.Number(\n",
    "                label='Max files to sync (leave blank for all)',\n",
    "                value=20,\n",
    "            )\n",
    "            sync_btn = gr.Button('Sync Google Drive')\n",
    "            sync_status = gr.Markdown('No sync performed yet.')\n",
    "\n",
    "            sync_btn.click(\n",
    "                sync_drive_and_index,\n",
    "                inputs=[drive_folder, file_types, file_limit, sync_state],\n",
    "                outputs=[sync_status, sync_state],\n",
    "            )\n",
    "\n",
    "ui.launch(debug=True)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "env",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
